{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 长短期记忆（LSTM）\n",
    "\n",
    "\n",
    "本节将介绍另一种常用的门控循环神经网络：长短期记忆（long short-term memory，LSTM）[1]。它比门控循环单元的结构稍微复杂一点。\n",
    "\n",
    "\n",
    "## 长短期记忆\n",
    "\n",
    "LSTM 中引入了3个门，即输入门（input gate）、遗忘门（forget gate）和输出门（output gate），以及与隐藏状态形状相同的记忆细胞（某些文献把记忆细胞当成一种特殊的隐藏状态），从而记录额外的信息。\n",
    "\n",
    "\n",
    "### 输入门、遗忘门和输出门\n",
    "\n",
    "与门控循环单元中的重置门和更新门一样，如图6.7所示，长短期记忆的门的输入均为当前时间步输入$\\boldsymbol{X}_t$与上一时间步隐藏状态$\\boldsymbol{H}_{t-1}$，输出由激活函数为sigmoid函数的全连接层计算得到。如此一来，这3个门元素的值域均为$[0,1]$。\n",
    "\n",
    "![长短期记忆中输入门、遗忘门和输出门的计算](../img/lstm_0.svg)\n",
    "\n",
    "具体来说，假设隐藏单元个数为$h$，给定时间步$t$的小批量输入$\\boldsymbol{X}_t \\in \\mathbb{R}^{n \\times d}$（样本数为$n$，输入个数为$d$）和上一时间步隐藏状态$\\boldsymbol{H}_{t-1} \\in \\mathbb{R}^{n \\times h}$。\n",
    "时间步$t$的输入门$\\boldsymbol{I}_t \\in \\mathbb{R}^{n \\times h}$、遗忘门$\\boldsymbol{F}_t \\in \\mathbb{R}^{n \\times h}$和输出门$\\boldsymbol{O}_t \\in \\mathbb{R}^{n \\times h}$分别计算如下：\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\boldsymbol{I}_t &= \\sigma(\\boldsymbol{X}_t \\boldsymbol{W}_{xi} + \\boldsymbol{H}_{t-1} \\boldsymbol{W}_{hi} + \\boldsymbol{b}_i),\\\\\n",
    "\\boldsymbol{F}_t &= \\sigma(\\boldsymbol{X}_t \\boldsymbol{W}_{xf} + \\boldsymbol{H}_{t-1} \\boldsymbol{W}_{hf} + \\boldsymbol{b}_f),\\\\\n",
    "\\boldsymbol{O}_t &= \\sigma(\\boldsymbol{X}_t \\boldsymbol{W}_{xo} + \\boldsymbol{H}_{t-1} \\boldsymbol{W}_{ho} + \\boldsymbol{b}_o),\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "其中的$\\boldsymbol{W}_{xi}, \\boldsymbol{W}_{xf}, \\boldsymbol{W}_{xo} \\in \\mathbb{R}^{d \\times h}$和$\\boldsymbol{W}_{hi}, \\boldsymbol{W}_{hf}, \\boldsymbol{W}_{ho} \\in \\mathbb{R}^{h \\times h}$是权重参数，$\\boldsymbol{b}_i, \\boldsymbol{b}_f, \\boldsymbol{b}_o \\in \\mathbb{R}^{1 \\times h}$是偏差参数。\n",
    "\n",
    "\n",
    "### 候选记忆细胞\n",
    "\n",
    "接下来，长短期记忆需要计算候选记忆细胞$\\tilde{\\boldsymbol{C}}_t$。它的计算与上面介绍的3个门类似，但使用了值域在$[-1, 1]$的tanh函数作为激活函数，如图6.8所示。\n",
    "\n",
    "![长短期记忆中候选记忆细胞的计算](../img/lstm_1.svg)\n",
    "\n",
    "\n",
    "具体来说，时间步$t$的候选记忆细胞$\\tilde{\\boldsymbol{C}}_t \\in \\mathbb{R}^{n \\times h}$的计算为\n",
    "\n",
    "$$\\tilde{\\boldsymbol{C}}_t = \\text{tanh}(\\boldsymbol{X}_t \\boldsymbol{W}_{xc} + \\boldsymbol{H}_{t-1} \\boldsymbol{W}_{hc} + \\boldsymbol{b}_c),$$\n",
    "\n",
    "其中$\\boldsymbol{W}_{xc} \\in \\mathbb{R}^{d \\times h}$和$\\boldsymbol{W}_{hc} \\in \\mathbb{R}^{h \\times h}$是权重参数，$\\boldsymbol{b}_c \\in \\mathbb{R}^{1 \\times h}$是偏差参数。\n",
    "\n",
    "\n",
    "### 记忆细胞\n",
    "\n",
    "我们可以通过元素值域在$[0, 1]$的输入门、遗忘门和输出门来控制隐藏状态中信息的流动，这一般也是通过使用按元素乘法（符号为$\\odot$）来实现的。当前时间步记忆细胞$\\boldsymbol{C}_t \\in \\mathbb{R}^{n \\times h}$的计算组合了上一时间步记忆细胞和当前时间步候选记忆细胞的信息，并通过遗忘门和输入门来控制信息的流动：\n",
    "\n",
    "$$\\boldsymbol{C}_t = \\boldsymbol{F}_t \\odot \\boldsymbol{C}_{t-1} + \\boldsymbol{I}_t \\odot \\tilde{\\boldsymbol{C}}_t.$$\n",
    "\n",
    "\n",
    "如图6.9所示，遗忘门控制上一时间步的记忆细胞$\\boldsymbol{C}_{t-1}$中的信息是否传递到当前时间步，而输入门则控制当前时间步的输入$\\boldsymbol{X}_t$通过候选记忆细胞$\\tilde{\\boldsymbol{C}}_t$如何流入当前时间步的记忆细胞。如果遗忘门一直近似1且输入门一直近似0，过去的记忆细胞将一直通过时间保存并传递至当前时间步。这个设计可以应对循环神经网络中的梯度衰减问题，并更好地捕捉时间序列中时间步距离较大的依赖关系。\n",
    "\n",
    "![长短期记忆中记忆细胞的计算。这里的$\\odot$是按元素乘法](../img/lstm_2.svg)\n",
    "\n",
    "\n",
    "### 隐藏状态\n",
    "\n",
    "有了记忆细胞以后，接下来我们还可以通过输出门来控制从记忆细胞到隐藏状态$\\boldsymbol{H}_t \\in \\mathbb{R}^{n \\times h}$的信息的流动：\n",
    "\n",
    "$$\\boldsymbol{H}_t = \\boldsymbol{O}_t \\odot \\text{tanh}(\\boldsymbol{C}_t).$$\n",
    "\n",
    "这里的tanh函数确保隐藏状态元素值在-1到1之间。需要注意的是，当输出门近似1时，记忆细胞信息将传递到隐藏状态供输出层使用；当输出门近似0时，记忆细胞信息只自己保留。图6.10展示了长短期记忆中隐藏状态的计算。\n",
    "\n",
    "![长短期记忆中隐藏状态的计算。这里的$\\odot$是按元素乘法](../img/lstm_3.svg)\n",
    "\n",
    "\n",
    "## 读取数据集\n",
    "\n",
    "下面我们开始实现并展示长短期记忆。和前几节中的实验一样，这里依然使用周杰伦歌词数据集来训练模型作词。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "1"
    }
   },
   "outputs": [],
   "source": [
    "import d2lzh as d2l\n",
    "from mxnet import nd\n",
    "from mxnet.gluon import rnn\n",
    "\n",
    "(corpus_indices, char_to_idx, idx_to_char,\n",
    " vocab_size) = d2l.load_data_jay_lyrics()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 从零开始实现\n",
    "\n",
    "我们先介绍如何从零开始实现长短期记忆。\n",
    "\n",
    "### 初始化模型参数\n",
    "\n",
    "下面的代码对模型参数进行初始化。超参数`num_hiddens`定义了隐藏单元的个数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "2"
    }
   },
   "outputs": [],
   "source": [
    "num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size\n",
    "ctx = d2l.try_gpu()\n",
    "\n",
    "def get_params():\n",
    "    def _one(shape):\n",
    "        return nd.random.normal(scale=0.01, shape=shape, ctx=ctx)\n",
    "\n",
    "    def _three():\n",
    "        return (_one((num_inputs, num_hiddens)),\n",
    "                _one((num_hiddens, num_hiddens)),\n",
    "                nd.zeros(num_hiddens, ctx=ctx))\n",
    "\n",
    "    W_xi, W_hi, b_i = _three()  # 输入门参数\n",
    "    W_xf, W_hf, b_f = _three()  # 遗忘门参数\n",
    "    W_xo, W_ho, b_o = _three()  # 输出门参数\n",
    "    W_xc, W_hc, b_c = _three()  # 候选记忆细胞参数\n",
    "    # 输出层参数\n",
    "    W_hq = _one((num_hiddens, num_outputs))\n",
    "    b_q = nd.zeros(num_outputs, ctx=ctx)\n",
    "    # 附上梯度\n",
    "    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,\n",
    "              b_c, W_hq, b_q]\n",
    "    for param in params:\n",
    "        param.attach_grad()\n",
    "    return params"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义模型\n",
    "\n",
    "在初始化函数中，长短期记忆的隐藏状态需要返回额外的形状为(批量大小, 隐藏单元个数)的值为0的记忆细胞。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "3"
    }
   },
   "outputs": [],
   "source": [
    "def init_lstm_state(batch_size, num_hiddens, ctx):\n",
    "    return (nd.zeros(shape=(batch_size, num_hiddens), ctx=ctx),\n",
    "            nd.zeros(shape=(batch_size, num_hiddens), ctx=ctx))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面根据长短期记忆的计算表达式定义模型。需要注意的是，只有隐藏状态会传递到输出层，而记忆细胞不参与输出层的计算。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "4"
    }
   },
   "outputs": [],
   "source": [
    "def lstm(inputs, state, params):\n",
    "    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,\n",
    "     W_hq, b_q] = params\n",
    "    (H, C) = state\n",
    "    outputs = []\n",
    "    for X in inputs:\n",
    "        I = nd.sigmoid(nd.dot(X, W_xi) + nd.dot(H, W_hi) + b_i)\n",
    "        F = nd.sigmoid(nd.dot(X, W_xf) + nd.dot(H, W_hf) + b_f)\n",
    "        O = nd.sigmoid(nd.dot(X, W_xo) + nd.dot(H, W_ho) + b_o)\n",
    "        C_tilda = nd.tanh(nd.dot(X, W_xc) + nd.dot(H, W_hc) + b_c)\n",
    "        C = F * C + I * C_tilda\n",
    "        H = O * C.tanh()\n",
    "        Y = nd.dot(H, W_hq) + b_q\n",
    "        outputs.append(Y)\n",
    "    return outputs, (H, C)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练模型并创作歌词\n",
    "\n",
    "同上一节一样，我们在训练模型时只使用相邻采样。设置好超参数后，我们将训练模型并根据前缀“分开”和“不分开”分别创作长度为50个字符的一段歌词。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "5"
    }
   },
   "outputs": [],
   "source": [
    "num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2\n",
    "pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们每过40个迭代周期便根据当前训练的模型创作一段歌词。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 40, perplexity 214.436824, time 0.63 sec\n",
      " - 分开 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我\n",
      " - 不分开 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 80, perplexity 68.338125, time 0.62 sec\n",
      " - 分开 我想你这你 我不要 我不我 我不我 我不要 我不我 我不我 我不我 我不我 我不我 我不我 我不我\n",
      " - 不分开 我想你这你 我不要 我不要 我不要 我不要 我不要 我不不 我不不 我不不 我不不 我不不 我不不\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 120, perplexity 17.466764, time 0.63 sec\n",
      " - 分开 我想你你 你子我 一话是 是怎么 一直怎 我的上空 是一段 停诉了 一片段 停片了 装片了 装片了\n",
      " - 不分开 你知你的爱笑 像知  想想你的太笑 像通  又给了我想你 你发  我想了你了我有 我想 你想很久久\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 160, perplexity 4.222124, time 0.63 sec\n",
      " - 分开 你是是 是子她睛  有话 对念的脚手 干什么 干什么 我满了那信的手 干什么 干什么 什么我的手开\n",
      " - 不分开 你已经没你单大着我的你 你 我想我爸的脑膀膀家题 你已说说 我不能再想 我不要再我 我不 我不 我\n"
     ]
    }
   ],
   "source": [
    "d2l.train_and_predict_rnn(lstm, get_params, init_lstm_state, num_hiddens,\n",
    "                          vocab_size, ctx, corpus_indices, idx_to_char,\n",
    "                          char_to_idx, False, num_epochs, num_steps, lr,\n",
    "                          clipping_theta, batch_size, pred_period, pred_len,\n",
    "                          prefixes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 简洁实现\n",
    "\n",
    "在Gluon中我们可以直接调用`rnn`模块中的`LSTM`类。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "6"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 40, perplexity 216.582305, time 0.05 sec\n",
      " - 分开 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我\n",
      " - 不分开 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 80, perplexity 64.757303, time 0.05 sec\n",
      " - 分开 你不你的爱爱 我想你你你你你 我想你 你你我 别不了 我不要 我不要 我不了 我不了 我不了 我不\n",
      " - 不分开 我想你的爱爱 我想要你 我想要这不 我想要你 我不要 我不要 我不要 我不要 我不了 我不了 我不\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 120, perplexity 13.708151, time 0.05 sec\n",
      " - 分开 你小么 一颗我的在江南 我牵能背 干两两步 在子村外 在一村碗的粥边 配上的客柳 你成的在娘人 一\n",
      " - 不分开 我知儿人不起 不知不觉 快使了人落棍 哼哼哈兮 快使用双截棍 哼哼哈兮 快使用双截棍 哼哼哈兮 快\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 160, perplexity 3.625087, time 0.05 sec\n",
      " - 分开 别小么 是你一睛在江南等我 泪成休 语沉默 娘子她依旧每日折一枝杨柳 在小村外的溪边河口 默默的在\n",
      " - 不分开 我想道你不经 不知不觉 又过了一个秋 后知后觉 我该好好生活 我该好好生活 不知不觉 你已经离开我\n"
     ]
    }
   ],
   "source": [
    "lstm_layer = rnn.LSTM(num_hiddens)\n",
    "model = d2l.RNNModel(lstm_layer, vocab_size)\n",
    "d2l.train_and_predict_rnn_gluon(model, num_hiddens, vocab_size, ctx,\n",
    "                                corpus_indices, idx_to_char, char_to_idx,\n",
    "                                num_epochs, num_steps, lr, clipping_theta,\n",
    "                                batch_size, pred_period, pred_len, prefixes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n",
    "\n",
    "* 长短期记忆的隐藏层输出包括隐藏状态和记忆细胞。只有隐藏状态会传递到输出层。\n",
    "* 长短期记忆的输入门、遗忘门和输出门可以控制信息的流动。\n",
    "* 长短期记忆可以应对循环神经网络中的梯度衰减问题，并更好地捕捉时间序列中时间步距离较大的依赖关系。\n",
    "\n",
    "\n",
    "\n",
    "## 练习\n",
    "\n",
    "* 调节超参数，观察并分析对运行时间、困惑度以及创作歌词的结果造成的影响。\n",
    "* 在相同条件下，比较长短期记忆、门控循环单元和不带门控的循环神经网络的运行时间。\n",
    "* 既然候选记忆细胞已通过使用tanh函数确保值域在-1到1之间，为什么隐藏状态还需要再次使用tanh函数来确保输出值域在-1到1之间？\n",
    "\n",
    "\n",
    "## 参考文献\n",
    "\n",
    "[1] Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory. Neural computation, 9(8), 1735-1780.\n",
    "\n",
    "## 扫码直达[讨论区](https://discuss.gluon.ai/t/topic/4049)\n",
    "\n",
    "![](../img/qr_lstm.svg)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}